# -*- coding: utf-8 -*-
from __future__ import absolute_import
from keras import backend as K
from keras.layers.core import Layer,MaskedLayer
from keras import activations, initializations, regularizers, constraints
import numpy as np
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD
from keras.utils import np_utils
import tensorflow as tf

def Deconv_model(img_channels, img_rows, img_cols):
    """Build the convolution / transposed-convolution stack.

    The network is two 3x3 ``Convolution2D`` layers (16 then 32 filters)
    followed by two 3x3 ``Convolution2D_Transpose`` layers (64 then 32
    filters). ReLU follows every layer except the last, which is followed
    by a linear activation. Dropout(0.25) is inserted after the second
    and third activations.

    # Arguments
        img_channels: number of channels of the input images.
        img_rows: height of the input images.
        img_cols: width of the input images.

    # Returns
        An uncompiled ``Sequential`` model.
    """
    # NOTE(review): every layer is created with ``b_constraint=0`` exactly
    # as before; confirm that 0 (rather than None or a constraint object)
    # is what was intended.
    def _layer_sequence():
        # Layers are yielded lazily so each one is constructed immediately
        # before it is appended to the model.
        yield Convolution2D(16, 3, 3, b_constraint=0,
                            input_shape=(img_channels, img_rows, img_cols))
        yield Activation('relu')
        yield Convolution2D(32, 3, 3, b_constraint=0)
        yield Activation('relu')
        yield Dropout(0.25)
        yield Convolution2D_Transpose(64, 3, 3, b_constraint=0)
        yield Activation('relu')
        yield Dropout(0.25)
        yield Convolution2D_Transpose(32, 3, 3, b_constraint=0)
        yield Activation('linear')

    net = Sequential()
    for layer in _layer_sequence():
        net.add(layer)
    return net


def conv_output_length(input_length, filter_size, border_mode, stride):
    """Return the length of one spatial axis after a strided convolution.

    # Arguments
        input_length: size of the axis before the convolution, or None
            when it is unknown.
        filter_size: size of the kernel along this axis (only used in
            'valid' mode).
        border_mode: either 'same' or 'valid'.
        stride: step of the convolution along this axis.

    # Returns
        None if `input_length` is None, otherwise the un-strided length
        divided by `stride`, rounded up.
    """
    if input_length is None:
        return None
    assert border_mode in {'same', 'valid'}
    if border_mode == 'valid':
        # The kernel must fit entirely inside the input.
        unstrided = input_length - filter_size + 1
    else:
        # 'same' keeps the un-strided length equal to the input length.
        unstrided = input_length
    # Ceiling division by the stride.
    return (unstrided + stride - 1) // stride

class Convolution2D_Transpose(Layer):
    """Transposed 2D convolution layer built on ``tf.nn.conv2d_transpose``.

    Input and output use the channels-first layout
    ``(samples, channels, rows, cols)``; the channel count is read from
    ``input_shape[1]``. Internally the tensor is permuted to
    ``(samples, rows, cols, channels)`` for the TensorFlow op and permuted
    back afterwards.

    # Arguments
        nb_filter: number of output channels.
        nb_row: kernel height.
        nb_col: kernel width.
        init: name of the weight initialization function.
        activation: name of the activation applied to the output.
        weights: optional list of initial weights ``[W, b]``.
        border_mode: 'valid' or 'same'.
        subsample: tuple ``(row_stride, col_stride)``.
        deconv_shape: kept for backward compatibility. The output shape is
            derived from the input shape, and the ``deconv_shape``
            attribute is overwritten with it whenever ``output_shape`` is
            read.
        W_regularizer, b_regularizer, activity_regularizer: regularizers
            for the kernel, the bias and the layer output.
        W_constraint, b_constraint: constraints for the kernel and bias.

    # Raises
        Exception: if `border_mode` is neither 'valid' nor 'same'.
    """
    input_ndim = 4

    def __init__(self, nb_filter, nb_row, nb_col,
                 init='glorot_uniform', activation='linear', weights=None,
                 border_mode='valid', subsample=(1, 1),deconv_shape = (),
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None, **kwargs):

        if border_mode not in {'valid', 'same'}:
            raise Exception('Invalid border mode for Convolution2D_Transpose:',
                            border_mode)
        self.nb_filter = nb_filter
        self.nb_row = nb_row
        self.nb_col = nb_col
        self.deconv_shape = deconv_shape
        self.init = initializations.get(init)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.constraints = [self.W_constraint, self.b_constraint]

        self.initial_weights = weights
        self.input = K.placeholder(ndim=4)
        super(Convolution2D_Transpose, self).__init__(**kwargs)

    def build(self):
        """Create the kernel and bias and register the regularizers."""
        stack_size = self.input_shape[1]
        # tf.nn.conv2d_transpose expects its filter laid out as
        # [height, width, output_channels, in_channels].
        self.W_shape = (self.nb_row, self.nb_col, self.nb_filter, stack_size)
        self.W = self.init(self.W_shape)
        self.b = K.zeros((self.nb_filter,))
        self.trainable_weights = [self.W, self.b]
        self.regularizers = []

        if self.W_regularizer:
            self.W_regularizer.set_param(self.W)
            self.regularizers.append(self.W_regularizer)

        if self.b_regularizer:
            self.b_regularizer.set_param(self.b)
            self.regularizers.append(self.b_regularizer)

        if self.activity_regularizer:
            self.activity_regularizer.set_layer(self)
            self.regularizers.append(self.activity_regularizer)

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def _deconv_length(self, input_length, filter_size, stride):
        """Return the size of one spatial axis after the transposed convolution.

        `input_length` may be an int, a scalar tensor, or None (unknown),
        in which case None is returned.
        """
        if input_length is None:
            return None
        if self.border_mode == 'same':
            return input_length * stride
        # 'valid': inverse of the forward relation out = ceil((in - k + 1) / s).
        return (input_length - 1) * stride + filter_size

    @property
    def output_shape(self):
        """Channels-first output shape ``(samples, nb_filter, rows, cols)``.

        Also stores the result in ``self.deconv_shape``.
        """
        input_shape = self.input_shape
        rows = self._deconv_length(input_shape[2], self.nb_row,
                                   self.subsample[0])
        cols = self._deconv_length(input_shape[3], self.nb_col,
                                   self.subsample[1])
        self.deconv_shape = (input_shape[0], self.nb_filter, rows, cols)
        return self.deconv_shape

    def get_output(self, train=False):
        """Return the activated, bias-added transposed convolution of the input."""
        X = self.get_input(train)
        # Channels-first -> channels-last, the layout the TF op works on.
        X = K.permute_dimensions(X, (0, 2, 3, 1))

        row_stride, col_stride = self.subsample
        runtime_shape = tf.shape(X)
        _, _, out_rows, out_cols = self.output_shape
        # Fall back to run-time sizes for spatial axes that are not known
        # statically.
        if out_rows is None:
            out_rows = self._deconv_length(runtime_shape[1], self.nb_row,
                                           row_stride)
        if out_cols is None:
            out_cols = self._deconv_length(runtime_shape[2], self.nb_col,
                                           col_stride)

        # The batch size is only known at run time, so the target shape has
        # to be a tensor. tf.stack is the newer name of tf.pack.
        pack = getattr(tf, 'stack', None) or tf.pack
        target_shape = pack([runtime_shape[0], out_rows, out_cols,
                             self.nb_filter])

        conv_out = tf.nn.conv2d_transpose(
            X, self.W,
            output_shape=target_shape,
            strides=[1, row_stride, col_stride, 1],
            padding=self.border_mode.upper())

        output = conv_out + K.reshape(self.b, (1, 1, 1, self.nb_filter))
        # Back to channels-first so following layers see the same layout
        # this layer received.
        output = K.permute_dimensions(output, (0, 3, 1, 2))
        return self.activation(output)

    def get_config(self):
        """Return the layer configuration merged with the base-class config."""
        config = {'name': self.__class__.__name__,
                  'nb_filter': self.nb_filter,
                  'nb_row': self.nb_row,
                  'nb_col': self.nb_col,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'border_mode': self.border_mode,
                  'subsample': self.subsample,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None}
        base_config = super(Convolution2D_Transpose, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))



